library(tidyverse)
library(cjbart)
library(parallel)
library(haven)
library(pals)
library(sandwich)
library(lmtest)
library(cowplot)
library(xtable)

# Set the `mc.cores` option (the default worker count read by
# parallel::mclapply() and similar functions) to 8.
options(mc.cores = 8)

#### Refresh data ####

# Project helper functions; set_refs(), var_format() and var_group() are used
# further down and are not defined in this file.
source("0_conjoint_functions.R")

conjoint_data <- set_refs(read_csv("formatted_data/conjoint_data.csv"))

# Modelling data set:
#   * keep rows with a non-missing `jobs` attribute and drop country "CHN";
#   * prefix the gdp/jobs levels with "gdp_"/"jobs_" while keeping the level
#     order of `conjoint_data`;
#   * drop the outcome/bookkeeping columns and covariate families that are
#     not used as model inputs.
bin_data <- conjoint_data %>%
  filter(!is.na(jobs), country != "CHN") %>%
  mutate(
    gdp = factor(paste0("gdp_", gdp),
                 levels = paste0("gdp_", levels(conjoint_data$gdp))),
    jobs = factor(paste0("jobs_", jobs),
                  levels = paste0("jobs_", levels(conjoint_data$jobs)))
  ) %>%
  select(
    -c(choice_cont, weights, id,
       food_indx, who_indx,
       region_merge, attention,
       system, POLCONV_VDEM),
    -ends_with("_sub"),
    -starts_with("politics_"),
    -starts_with("health_pol_"),
    -matches("(C|E)[1-9]")
  )
  # select(-all_of(c("health_spend_pca","health_compl_pca","intl_flights"))) %>% 

##### TABLE S15 -- Hyperparameter tuning #####

# source("0_hyperparameter_tuning.R")

#### Model training ####

# Fix the RNG seed. The model-training code that follows is commented out, so
# this seed stays in effect until the next set.seed() call further down.
set.seed(89)

# t_mod_0 <- Sys.time()
# cj_mod <- cjbart(bin_data,
#                  Y = "choice_bin",
#                  id = "c_id",
#                  round = "round",
#                  ntree = 750,
#                  cores = 8)
# t_mod_1 <- Sys.time()
# 
# write_rds(cj_mod, paste0("models/conjoint_mod_",Sys.Date(),".RDS"))
# 
# resp_imces <- IMCE(bin_data,
#                    cj_mod,
#                    attribs = c("gdp","jobs","supplies","lockdown","deaths","vaccinated"),
#                    ref_levels = c(gdp = "gdp_0%",
#                                   jobs = "jobs_0%",
#                                   supplies = "Quick to procure",
#                                   lockdown = "10 weeks",
#                                   deaths = "10 per million",
#                                   vaccinated = "5%"),
#                    cores = 8)
# write_rds(resp_imces, paste0("models/conjoint_",Sys.Date(),".RDS"))
# 
# set.seed(89)
# resp_vimp <- het_vimp(resp_imces, cores = 8)
# write_rds(resp_vimp, paste0("models/vimp_",Sys.Date(),".RDS"))

#### IMCE analysis ####
# Previously saved objects (file names are dated 2024-08-13): the
# respondent-level IMCE object and the fitted conjoint model.
resp_imces <- readRDS("models/conjoint_2024-08-13.RDS")
cj_mod <- readRDS("models/conjoint_mod_2024-08-13.RDS")

# Display order of the attribute levels once the "gdp_"/"jobs_" prefixes have
# been stripped; used as factor levels in the tables and figures below.
macro_levels <- c(
  # deaths
  "30 per million", "50 per million", "70 per million", "90 per million",
  # signed percentage levels (gdp / jobs after prefix removal)
  "-10%", "-5%", "+5%", "+10%",
  # unsigned percentage levels
  "15%", "25%", "50%", "75%",
  # lockdown length
  "20 weeks", "30 weeks", "40 weeks",
  # supplies
  "Slow to procure"
)

##### TABLE S2 -- COVARIATE SUMMARY #####
# Conjoint design columns and the outcome; everything else in `bin_data` is
# treated as a subject-level covariate.
atts_ignore <- c("round", "choice_bin", "gdp", "jobs", "supplies", "lockdown", "deaths", "vaccinated")

# Distinct covariate rows (de-duplicated while `c_id` is still present, then
# `c_id` is dropped).
covar_data <- bin_data %>%
  select(-all_of(atts_ignore)) %>%
  distinct() %>%
  select(-c_id)

covar_sum <- data.frame(covar = colnames(covar_data))

# Variable type: "binary" when the column only holds 0/1/NA, otherwise the
# (first) class of the column. Taking the first class guarantees a single
# string per column even for multi-class vectors.
covar_sum$type <- vapply(
  covar_sum$covar,
  function(x) {
    if (length(setdiff(covar_data[[x]], c(NA, 0, 1))) == 0) {
      "binary"
    } else {
      class(covar_data[[x]])[1]
    }
  },
  character(1)
)

# Percentage of missing values per covariate.
covar_sum$missing <- vapply(
  covar_sum$covar,
  function(x) 100 * sum(is.na(covar_data[[x]])) / length(covar_data[[x]]),
  numeric(1)
)

# Describe a single covariate column.
#
# x:    the column values.
# type: its type string as stored in `covar_sum$type`.
#
# Returns a list with
#   vals: printable value range ("[min, max]" for numeric, "{0,1}" for binary,
#         otherwise the unique values, truncated to the first five when there
#         are more than 20);
#   ct:   central tendency (rounded mean for numeric/binary, otherwise the
#         most frequent value, NA if there are no non-missing values);
#   sd:   rounded standard deviation for numeric/binary, otherwise "--";
#   N:    number of non-missing observations.
describe_covariate <- function(x, type) {
  quantitative <- type %in% c("numeric", "binary")

  vals <- if (type == "numeric") {
    paste0("[",
           round(min(x, na.rm = TRUE), 3),
           ", ",
           round(max(x, na.rm = TRUE), 3),
           "]")
  } else if (type == "binary") {
    "{0,1}"
  } else {
    distinct_vals <- unique(x)
    if (length(distinct_vals) > 20) {
      paste0("{", paste(distinct_vals[1:5], collapse = " | "), " | ...}")
    } else {
      paste0("{", paste(distinct_vals, collapse = " | "), "}")
    }
  }

  ct <- if (quantitative) {
    paste0(round(mean(x, na.rm = TRUE), 2))
  } else {
    # sort(-table(x)) orders the categories by decreasing frequency
    ranked <- names(sort(-table(x)))
    if (length(ranked) > 0) ranked[1] else NA_character_
  }

  sd_txt <- if (quantitative) {
    paste0(round(sd(x, na.rm = TRUE), 2))
  } else {
    "--"
  }

  list(vals = vals, ct = ct, sd = sd_txt, N = sum(!is.na(x)))
}

# seq_len() keeps this safe when there are no covariates (1:0 would iterate).
covar_stats <- lapply(seq_len(nrow(covar_sum)), function(i) {
  describe_covariate(covar_data[[covar_sum$covar[i]]], covar_sum$type[i])
})

covar_sum$vals <- vapply(covar_stats, function(s) s$vals, character(1))
covar_sum$ct <- vapply(covar_stats, function(s) s$ct, character(1))
covar_sum$sd <- vapply(covar_stats, function(s) s$sd, character(1))
covar_sum$N <- vapply(covar_stats, function(s) s$N, integer(1))

# Survey question (or external source) behind each covariate. Covariates that
# are not listed get NA.
covar_sum$surveyq <- unname(
  c(
    age = "2.1",
    cases_w2 = "--",
    country = "--",
    covidexp_index = "9.4",
    covidexp_index_abovemed = "9.4",
    deaths_w2 = "--",
    EDUCATION_LEVEL = "2.3",
    eq5d_rate_delta = "9.9-10",
    eq5d_rate_now = "9.10",
    food_pca = "5.17-41",
    gender = "2.2",
    hh_inc_delta = "5.13",
    hh_inc_obj = "5.12",
    ideology = "10.1",
    income_abovemed = "5.8",
    pop2019_w2 = "--",
    round = "--",
    vac_hes_4_1 = "7.4",
    vac_hes_7 = "7.7",
    who_pca = "8.3",
    dep_children = "10.11",
    marital_status = "10.10",
    REGION_0 = "2.4",
    subj_vaccinated = "6.2",
    subj_vacc_refused = "6.1-2",
    C1 = "OxCGRT: School closures (in days)",
    C2 = "OxCGRT: Workplace closures (in days)",
    C3 = "OxCGRT: Public event cancellations (in days)",
    C4 = "OxCGRT: Gathering limits (in days)",
    C5 = "OxCGRT: Transport closures (in days)",
    C6 = "OxCGRT: Shelter-in-place (in days)",
    C7 = "OxCGRT: Movement restrictions (in days)",
    C8 = "OxCGRT: International travel restrictions (in days)",
    E1 = "OxCGRT: Income support (in days)",
    E2 = "OxCGRT: Debt/contract relief (in days)",
    gov_relect = "10.7",
    gov_rate = "10.6",
    C_pca = "OxCGRT [C1-8]",
    E_pca = "OxCGRT [E1-2]",
    health_spend_pca = "13.1/17/18",
    health_compl_pca = "9.6",
    intl_flights = "13.19"
  )[as.character(covar_sum$covar)]
)

# Human-readable variable names (var_format() comes from the sourced helpers)
# and type labels; any other type string becomes NA.
covar_sum$covar <- var_format(covar_sum$covar)
covar_sum$type <- unname(
  c(binary = "Binary",
    character = "Categorical",
    numeric = "Numeric")[as.character(covar_sum$type)]
)

# Write the table body only (no header, row names or rules) so it can be
# embedded in a hand-written LaTeX table.
print(
  xtable(select(covar_sum, covar, surveyq, type, missing, vals, ct, sd, N),
         digits = 2),
  include.rownames = FALSE, include.colnames = FALSE,
  only.contents = TRUE,
  hline.after = NULL,
  file = "tables/covar_desc.tex"
)

##### TABLES S3-6 -- COVAR SUM PER COUNTRY ####
# One row per covariate (excluding `country` itself); one column per country
# is added below.
country_sum <- data.frame(covar = colnames(covar_data)) %>%
  filter(covar != "country")

# Same type rule as the pooled covariate table: 0/1/NA columns are "binary",
# everything else is labelled by its (first) class.
country_sum$type <- vapply(
  country_sum$covar,
  function(x) {
    if (length(setdiff(covar_data[[x]], c(NA, 0, 1))) == 0) {
      "binary"
    } else {
      class(covar_data[[x]])[1]
    }
  },
  character(1)
)

for (country in unique(covar_data$country)) {
  # which() drops rows whose country is NA instead of turning them into
  # all-NA rows, as plain logical indexing would.
  country_tmp <- covar_data[which(covar_data$country == country), ]

  # "mean (sd)" for numeric/binary covariates, modal category otherwise.
  # seq_len() keeps this safe when there are no covariates.
  country_sum[[country]] <- vapply(
    seq_len(nrow(country_sum)),
    function(i) {
      x <- country_tmp[[country_sum$covar[i]]]
      if (country_sum$type[i] %in% c("numeric", "binary")) {
        paste0(round(mean(x, na.rm = TRUE), 2),
               " (",
               round(sd(x, na.rm = TRUE), 2),
               ")")
      } else {
        # sort(-table(x)) orders categories by decreasing frequency; a
        # country with no non-missing values yields NA rather than an error
        ranked <- names(sort(-table(x)))
        if (length(ranked) > 0) ranked[1] else NA_character_
      }
    },
    character(1)
  )
}


country_sum$covar <- var_format(country_sum$covar)
country_sum$type <- NULL

# One LaTeX fragment per group of countries; the file name lists the group
# members, e.g. tables/covar_desc_AUS_BR_CAN_CHL.tex.
for (g in list(c("AUS", "BR", "CAN", "CHL"),
               c("COL", "FR", "GHA", "IND"),
               c("IT", "JPN", "SP", "UGA"),
               c("UK", "US", "ZAF"))) {
  country_sum %>%
    select(all_of(c("covar", g))) %>%
    xtable(digits = 2) %>%
    print(include.rownames = FALSE, include.colnames = FALSE,
          only.contents = TRUE,
          hline.after = NULL,
          file = paste0("tables/covar_desc_", paste0(g, collapse = "_"), ".tex"))
}



##### TABLE S7 -- SUBJECT RECRUITMENT #####
## Produced by hand
#### FIGURE 3 -- IMCE DISTRIBUTION PLOT ####
# Long table of IMCEs: one row per respondent and attribute level, with the
# per-level median and 2.5%/97.5% quantiles attached to every row.
level_distributions <- resp_imces$imce %>% 
  filter(!duplicated(c_id)) %>%   # keep the first row of each respondent
  select(all_of(resp_imces$att_levels), c_id) %>% 
  mutate(subj = row_number()) %>% 
  pivot_longer(-all_of(c("subj", "c_id")), names_to = "level", values_to = "imce") %>% 
  left_join(resp_imces$att_lookup, by = "level") %>% 
  mutate(Attribute = str_to_sentence(Attribute),
         Attribute = ifelse(Attribute == "Gdp","GDP",Attribute),
         # strip the prefixes added when building `bin_data`
         Level = gsub("gdp_","", Level),
         Level = gsub("jobs_","", Level),
         Level = factor(Level, levels = macro_levels)) %>% 
  group_by(level) %>% 
  # na.rm = TRUE throughout, so the quantiles treat missing IMCEs the same
  # way as the median (quantile() errors on NA otherwise)
  mutate(median_imce = median(imce, na.rm = TRUE),
         l95_imce = quantile(imce, 0.025, na.rm = TRUE),
         u95_imce = quantile(imce, 0.975, na.rm = TRUE)) %>% 
  ungroup()
  
library(ggh4x)

# Common x-axis range per attribute (padded by 0.02 on both sides) so that
# all levels of one attribute share the same scale.
dist_scales <- level_distributions %>% 
  group_by(Attribute) %>% 
  summarise(xmin = min(imce, na.rm = TRUE) - 0.02,
            xmax = max(imce, na.rm = TRUE) + 0.02) %>% 
  rename(Panel = Attribute)

ggplot(level_distributions, aes(x = imce, fill = Attribute)) +
  facet_wrap(~Attribute + Level, dir = "v", scales = "free", nrow = 4) +
  # grey band: central 95% of the IMCE distribution of the level
  geom_rect(aes(xmin = l95_imce, xmax = u95_imce, ymin = 0, ymax = Inf),
            fill = "grey80") +
  geom_vline(xintercept = 0, color = "indianred", linewidth = 0.4, linetype = "dotted") +
  geom_density(alpha = 0.75) +
  # dashed line: median IMCE of the level
  geom_vline(aes(xintercept = median_imce), linetype = "dashed") +
  # per-attribute x limits taken from `dist_scales`
  facetted_pos_scales(
    x = list(
      Attribute == "Deaths" ~ scale_x_continuous(
        limits = unlist(dist_scales[dist_scales$Panel == "Deaths", c("xmin","xmax")])
        ),
      Attribute == "GDP" ~ scale_x_continuous(
        limits = unlist(dist_scales[dist_scales$Panel == "GDP", c("xmin","xmax")])
      ),
      Attribute == "Jobs" ~ scale_x_continuous(
        limits = unlist(dist_scales[dist_scales$Panel == "Jobs", c("xmin","xmax")])
      ),
      Attribute == "Lockdown" ~ scale_x_continuous(
        limits = unlist(dist_scales[dist_scales$Panel == "Lockdown", c("xmin","xmax")])
      ),
      Attribute == "Supplies" ~ scale_x_continuous(
        limits = unlist(dist_scales[dist_scales$Panel == "Supplies", c("xmin","xmax")])
      ),
      Attribute == "Vaccinated" ~ scale_x_continuous(
        limits = unlist(dist_scales[dist_scales$Panel == "Vaccinated", c("xmin","xmax")])
      )
    )
  ) +
  guides(fill = guide_legend(nrow = 1)) +
  labs(x = "IMCE", y = "") +
  theme(legend.position = "none",
        panel.background = element_blank())

# Saves the plot drawn above (ggsave() defaults to the last plot).
ggsave("figures/figure3.pdf", width = 10, height = 6)

##### TABLE S16 -- SUMMARY OF IMCES #####
# One row per attribute level with the mean, median and 2.5%/97.5% quantiles
# of the IMCE distribution, ordered by attribute and level.
imce_tab <- level_distributions %>% 
  group_by(level) %>% 
  # na.rm = TRUE to match how median_imce was computed
  mutate(Mean = mean(imce, na.rm = TRUE)) %>% 
  ungroup() %>% 
  mutate(Level = factor(Level, levels = macro_levels)) %>% 
  select(Attribute, Level, Mean, median_imce, l95_imce, u95_imce) %>% 
  distinct() %>% 
  arrange(Attribute, Level) %>% 
  # Show each attribute name only on its first row: blank it wherever it
  # repeats the row above. Vectorised, so tables with fewer than two rows
  # are handled too (a backwards `nrow:2` loop is not safe there).
  mutate(Attribute = if_else(coalesce(Attribute == lag(Attribute), FALSE),
                             "",
                             Attribute))

# Table body only; horizontal rules after the listed row numbers.
xtable(imce_tab, digits = 3) %>% 
  print(only.contents = TRUE, hline.after = c(4,8,12,15,16),
        booktabs = TRUE,
        include.rownames = FALSE, include.colnames = FALSE,
        file = "tables/imce_distributions.tex")


##### FIGURE S22 -- MAIN VIMP RESULTS #####
# Variable-importance results saved on 2024-08-13.
resp_vimp <- readRDS("models/vimp_2024-08-13.RDS")

# Label every attribute level with its attribute, a cleaned display label,
# and the covariate's group / display name (var_group() and var_format() come
# from the sourced helper file).
vimp_results <- resp_vimp$results %>% 
  mutate(attribute = case_when(Level == "Slow to procure" ~ "Procurement",
                               grepl("jobs_", Level) ~ "Job Growth",
                               grepl("gdp_", Level) ~ "GDP Growth",
                               grepl("per million", Level) ~ "Deaths",
                               grepl("weeks", Level) ~ "Lockdown Length",
                               # must come after the jobs_/gdp_ rules, which
                               # also contain "%"
                               grepl("%", Level) ~ "% Vaccinated"),
         out_lab = gsub("gdp_|jobs_","",Level),
         out_lab = factor(out_lab, levels = c("30 per million","50 per million","70 per million", "90 per million",
                                              "-10%","-5%","0%","+5%","+10%",
                                              "20 weeks","30 weeks","40 weeks",
                                              "Slow to procure",
                                              "15%","25%","50%","75%")),
         group = var_group(covar),
         group = factor(group,
                        levels = c("(Sub-) national context", "Demographics","Engagement","Economics","Health","Politics")),
         covar = var_format(covar),
         covar = factor(covar, levels = sort(unique(covar)))
         
  ) %>% 
  
  filter(covar != "Round")
  
# Heatmap: covariates on x, attribute levels on y, importance (x100) as fill.
ggplot(vimp_results, aes(x = covar, y = out_lab, fill = importance*100)) +
  facet_grid(attribute ~ group, space = "free", scales = "free") +
  geom_tile() +
  colorspace::scale_fill_continuous_sequential("Reds3", breaks = seq(0,100,25)) +
  labs(x = "Subject covariate", y = "Attribute-level", fill = "Importance") +
  theme(text = element_text(size = 16),
        legend.position = "right",
        axis.text.x = element_text(angle=45, vjust = 1, hjust = 1, size = 11),
        axis.text.y = element_text(size = 11),
        strip.text.y.right = element_text(angle = 0, face = "bold",
                                          color = "grey40", hjust = 0, size = 11),
        strip.text.x = element_text(color = "grey40", size = 11, face = "bold"),
        strip.background =element_rect(fill=NA),
        # `linewidth` replaces the deprecated `size` argument of
        # element_rect() (ggplot2 >= 3.4.0)
        panel.border = element_rect(color = "grey40", fill = NA, linewidth = 2),
        
        legend.frame = element_rect(color = "grey40", linewidth = 1))

ggsave("figures/vimp_main.pdf", dpi = 300, width = 14, height = 9)

#### FIGURE 5 -- IMCE distribution along economic and health dimensions ####

# Per-respondent mean absolute IMCE on the economic and the health dimension
# (lockdown counted as economic). `type` is "E" when the economic value is
# larger and "H" otherwise; `ideology` is merged in from the survey data.
IMCE_2d <- level_distributions %>% 
  mutate(dimension = case_when(Attribute %in% c("GDP","Jobs", "Lockdown") ~ "Economic",
                               Attribute %in% c("Deaths","Supplies","Vaccinated") ~ "Health")) %>% 
  group_by(c_id, dimension) %>% 
  # .groups = "drop": return an ungrouped table instead of one that is still
  # grouped by c_id (also silences the regrouping message)
  summarise(valence = mean(abs(imce)), .groups = "drop") %>% 
  pivot_wider(names_from = "dimension", values_from = "valence") %>% 
  mutate(type = ifelse(Economic > Health, "E","H")) %>% 
  
  left_join(conjoint_data %>% select(c_id, ideology) %>% distinct(), by = "c_id")
  
# Scatter of the two dimensions; the dashed identity line separates
# respondents classified "E" from those classified "H".
ggplot(IMCE_2d, aes(x = Economic, y = Health, color = type)) + 
  geom_point(alpha = 0.15) +
  geom_abline(linetype = "dashed") +
  scale_x_continuous(breaks = c(0,0.03,0.06,0.09)) +
  scale_y_continuous(breaks = c(0,0.03,0.06,0.09,0.12)) +
  labs(x = "Impact of economic signals", y = "Impact of health signals") +
  theme_minimal() +
  theme(legend.position = "none",
        text = element_text(size = 11),
        panel.grid.minor = element_blank())

ggsave("figures/figure5.pdf", dpi = 300, width = 4, height = 4.5)

# Two-sample t-test of ideology between the "E" and "H" respondents.
t.test(IMCE_2d$ideology[IMCE_2d$type == "E"],
       IMCE_2d$ideology[IMCE_2d$type == "H"])

##### TABLE S17 -- RECLASSIFYING LOCKDOWN ATTRIBUTE #####

# Classify respondents by which dimension moves them more.
#
# imces:        long IMCE table with columns c_id, Attribute and imce.
# subjects:     subject-level data with columns c_id and ideology.
# econ_attrs:   attributes counted as the economic dimension.
# health_attrs: attributes counted as the health dimension.
#
# Attributes in neither set are left out of the calculation. Returns one row
# per c_id with the mean absolute IMCE per dimension (Economic, Health),
# type ("E" if Economic > Health, else "H") and the subject's ideology.
classify_by_dimension <- function(imces, subjects, econ_attrs, health_attrs) {
  imces %>%
    mutate(dimension = case_when(Attribute %in% econ_attrs ~ "Economic",
                                 Attribute %in% health_attrs ~ "Health")) %>%
    # unclassified attributes would otherwise produce a spurious `NA` column
    filter(!is.na(dimension)) %>%
    group_by(c_id, dimension) %>%
    summarise(valence = mean(abs(imce)), .groups = "drop") %>%
    pivot_wider(names_from = "dimension", values_from = "valence") %>%
    mutate(type = ifelse(Economic > Health, "E", "H")) %>%
    left_join(subjects %>% select(c_id, ideology) %>% distinct(), by = "c_id")
}

# Redo calc but treat lockdown as a health attribute
IMCE_2d_lh <- classify_by_dimension(
  level_distributions, conjoint_data,
  econ_attrs = c("GDP", "Jobs"),
  health_attrs = c("Deaths", "Supplies", "Vaccinated", "Lockdown")
)

# Redo calc but remove lockdown from calc
IMCE_2d_ln <- classify_by_dimension(
  level_distributions, conjoint_data,
  econ_attrs = c("GDP", "Jobs"),
  health_attrs = c("Deaths", "Supplies", "Vaccinated")
)

# Ideology of "E" vs "H" respondents under each lockdown classification.
# `IMCE_2d` is the lockdown = economic version built for Figure 5.
ttest_econ <- t.test(IMCE_2d$ideology[IMCE_2d$type == "E"],
                     IMCE_2d$ideology[IMCE_2d$type == "H"])

ttest_health <- t.test(IMCE_2d_lh$ideology[IMCE_2d_lh$type == "E"],
                       IMCE_2d_lh$ideology[IMCE_2d_lh$type == "H"])

ttest_none <- t.test(IMCE_2d_ln$ideology[IMCE_2d_ln$type == "E"],
                     IMCE_2d_ln$ideology[IMCE_2d_ln$type == "H"])

ttest_p <- c(ttest_econ$p.value, ttest_health$p.value, ttest_none$p.value)

imce_2d_t <- data.frame(
  comparison = c("Lockdown = Economic", "Lockdown = Health", "Lockdown = None"),
  ideology_econ = c(ttest_econ$estimate[1],ttest_health$estimate[1],ttest_none$estimate[1]),
  ideology_health = c(ttest_econ$estimate[2],ttest_health$estimate[2],ttest_none$estimate[2]),
  t = c(ttest_econ$statistic,ttest_health$statistic,ttest_none$statistic),
  # very small p-values are reported as "<0.001" (vectorised ifelse)
  p = ifelse(ttest_p < 0.001, "<0.001", round(ttest_p, 3))
)


imce_2d_t$diff <- imce_2d_t$ideology_econ - imce_2d_t$ideology_health

imce_2d_t %>% 
  select(comparison, ideology_econ, ideology_health, diff, t, p) %>% 
  xtable(., digits = 3) %>% 
  print(include.rownames = FALSE,
        include.colnames = FALSE,
        only.contents = TRUE,
        hline.after = NULL,
        file = "tables/IMCE_2d_t_comparison.tex")


#### Cluster analysis ####

set.seed(89)

library(cluster)

# Clustering input: one row per respondent, one standardised column per
# attribute level. scale() returns a one-column matrix with centring/scaling
# attributes; as.numeric() stores plain numeric columns instead of matrix
# columns inside the data frame.
imce_data <- resp_imces$imce %>% 
  filter(!duplicated(c_id)) %>% 
  select(all_of(resp_imces$att_levels)) %>% 
  mutate(across(.cols = everything(), .fns = function (x) as.numeric(scale(x))))

##### FIGURES S24 -- Check k-means criteria #####
# Within-cluster sum of squares for k = 1..20; re-seeded so the plot does not
# depend on RNG use above.
set.seed(89)
factoextra::fviz_nbclust(imce_data, kmeans, method='wss', nstart = 25, 
                         iter.max = 100, k.max = 20)

ggsave("figures/wss_by_k.pdf", width = 8, height = 3)

##### FIGURES S26 -- CATEGORICAL DESCRIPTIVES OF SUBJECTS BY CLUSTER #####

# Categorical subject descriptors summarised by cluster.
descriptive_cat <- c("country", "gender", "hh_inc_delta")

# Four-cluster k-means solution on the standardised IMCEs (seeded for
# reproducibility).
set.seed(89)
fit <- kmeans(imce_data, centers = 4, iter.max = 200, nstart = 25)

# One row per respondent: IMCEs plus the politics_*, health_pol_* and
# attention columns taken from the survey data.
imce_clusters <- left_join(
  filter(resp_imces$imce, !duplicated(c_id)),
  filter(
    select(conjoint_data,
           c_id,
           starts_with("politics_"),
           starts_with("health_pol_"),
           attention),
    !duplicated(c_id)
  ),
  by = "c_id"
)

# The rows are in the same order as `imce_data` (same de-duplication of the
# same source), so the cluster labels can be attached positionally.
imce_clusters$cluster <- fit$cluster

# Per-cluster mean of every `politics_*` column, in long format
# (columns: var, cluster, mean). The summary table is transposed before
# pivoting, so the `cluster` column itself also ends up as a row
# (var == "cluster"). Not referenced again in the visible part of this script.
cluster_means <- imce_clusters %>% 
  group_by(cluster) %>% 
  summarise(across(starts_with("politics_"), function (x) mean(x, na.rm = TRUE))) %>% 
  # transpose: variables become rows, clusters become columns
  t() %>% as.data.frame() %>% 
  mutate(var = rownames(.)) %>% 
  # names_prefix strips a leading "V" from the column names (presumably the
  # auto-generated V1, V2, ... of as.data.frame()) to recover the cluster id
  pivot_longer(-var,
               names_prefix = "V",
               names_to = "cluster",
               values_to = "mean")

# Same layout as `cluster_means`, but holding the standard error reported by
# a one-sample t.test() of each `politics_*` column within each cluster.
# Not referenced again in the visible part of this script.
cluster_ses <- imce_clusters %>% 
  group_by(cluster) %>% 
  summarise(across(starts_with("politics_"), function (x) t.test(x)$stderr)) %>% 
  t() %>% as.data.frame() %>% 
  mutate(var = rownames(.)) %>% 
  pivot_longer(-var,
               names_prefix = "V",
               names_to = "cluster",
               values_to = "se")

# Each panel is a stacked bar per cluster showing the share of the cluster's
# subjects in each category (`c_n` is the cluster size).

# Country composition of each cluster.
cluster_desc_country <- imce_clusters %>%
  add_count(cluster, name = "c_n") %>%
  group_by(cluster, country) %>%
  summarise(prop = n() / c_n[1]) %>%
  ggplot(aes(y = as.factor(cluster), x = prop, fill = country)) +
  geom_col(position = "stack") +
  scale_fill_manual(values = kelly()[3:17]) +
  labs(x = "", y = "Cluster", fill = "Country") +
  theme_minimal()

# Gender composition of each cluster.
cluster_desc_gender <- imce_clusters %>%
  add_count(cluster, name = "c_n") %>%
  group_by(cluster, gender) %>%
  summarise(prop = n() / c_n[1]) %>%
  ggplot(aes(y = as.factor(cluster), x = prop, fill = gender)) +
  geom_col(position = "stack") +
  scale_fill_manual(values = c(kelly()[3:5], "#848482")) +
  labs(x = "", y = "Cluster", fill = "Gender") +
  theme_minimal()

# Household income change; the factor levels fix the stacking/legend order.
cluster_desc_hh_inc_delta <- imce_clusters %>%
  mutate(
    hh_inc_delta = factor(
      hh_inc_delta,
      levels = c("Gone up", "Stayed the same", "Gone down",
                 "Prefer not to say", "Do not know")
    )
  ) %>%
  add_count(cluster, name = "c_n") %>%
  group_by(cluster, hh_inc_delta) %>%
  summarise(prop = n() / c_n[1]) %>%
  ggplot(aes(y = as.factor(cluster), x = prop, fill = hh_inc_delta)) +
  geom_col(position = "stack") +
  scale_fill_manual(values = c("#008856", "#A1CAF1", "#BE0032", "#848482", "#222222")) +
  labs(x = "% of subjects", y = "Cluster", fill = "Household \nincome change") +
  theme_minimal()

# Stack the three panels vertically; the country panel gets extra height.
cluster_cats <- plot_grid(
  cluster_desc_country,
  cluster_desc_gender,
  cluster_desc_hh_inc_delta,
  nrow = 3, ncol = 1,
  rel_heights = c(1.3, 1, 1)
)

plot(cluster_cats)

# A4-sized PNG (210 x 297 mm).
ggsave(filename = "figures/cluster_descriptives_cats.png",
       plot = cluster_cats,
       width = 210, height = 297, units = "mm",
       dpi = 300)

##### FIGURES S27 -- NUMERICAL DESCRIPTIVES OF SUBJECTS BY CLUSTER #####

# Numeric / binary subject descriptors summarised by cluster.
descriptive_num <- c("age","ideology","food_pca","who_pca",
                     "income_abovemed","covidexp_index_abovemed")

# Cluster means with 95% intervals (mean +/- 1.96 * se), one facet per
# variable.
cluster_desc_num <- imce_clusters %>% 
  group_by(cluster) %>% 
  summarise(across(all_of(descriptive_num), ~mean(.x, na.rm = TRUE)),
            n = n()) %>% 
  pivot_longer(-all_of(c("cluster","n")), names_to = "variable", values_to = "mean") %>% 
  left_join(
    {
      imce_clusters %>% 
        group_by(cluster) %>% 
        summarise(across(all_of(descriptive_num), ~sd(.x, na.rm = TRUE))) %>% 
        pivot_longer(-cluster, values_to = "sd")
    },
    by = c("cluster","variable" = "name")) %>% 
  
  # Standard errors: for the two binary indicators use the standard error of
  # a proportion, sqrt(p * (1 - p) / n) (the same formula as the attention
  # plot further down); for the other variables use sd / sqrt(n).
  # NOTE: n is the cluster size, including subjects with missing values.
  mutate(se = ifelse(variable %in% c("income_abovemed","covidexp_index_abovemed"),
                     sqrt((mean*(1-mean))/n),
                     sd/sqrt(n)),
         
         variable = var_format(variable)
  ) %>% 
  
  ggplot(aes(x = mean, y = variable, color = as.factor(cluster))) +
  facet_wrap(~ variable, scales = "free", ncol = 1) +
  geom_point(position = position_dodge(width = 0.5), size = 2) +
  geom_errorbarh(aes(xmin = mean - 1.96*se, xmax = mean + 1.96*se),
                 position = position_dodge(width = 0.5)) +
  labs(x = "", y = "", color = "Cluster") +
  theme_minimal() +
  theme(axis.text.y = element_blank(),
        legend.position = "bottom",
        strip.text = element_text(face = "bold"))

plot(cluster_desc_num)

# `filename`/`plot` spelled out instead of relying on partial matching of
# `file =` and positional matching of the plot.
ggsave(filename = "figures/cluster_descriptives_numeric.png",
       plot = cluster_desc_num,
       device = "png",
       width = 6, height = 6, dpi = 300)

##### FIGURES S28 -- CLUSTER AVERAGE MARGINAL EFFECTS  #####

# Mean IMCE of every attribute level within each cluster, in long format
# (var, cluster, mean). The summary is transposed first, so the `cluster`
# column also appears as a row; it is filtered out before plotting.
imce_cl_means <- imce_clusters %>%
  group_by(cluster) %>%
  summarise(across(all_of(resp_imces$att_levels), function(v) mean(v, na.rm = TRUE))) %>%
  t() %>%
  as.data.frame() %>%
  rownames_to_column("var") %>%
  pivot_longer(-var,
               names_to = "cluster",
               names_prefix = "V",
               values_to = "mean")

# Matching standard errors: sd / sqrt(cluster size).
imce_cl_ses <- imce_clusters %>%
  group_by(cluster) %>%
  summarise(across(all_of(resp_imces$att_levels), function(v) sd(v) / sqrt(length(v)))) %>%
  t() %>%
  as.data.frame() %>%
  rownames_to_column("var") %>%
  pivot_longer(-var,
               names_to = "cluster",
               names_prefix = "V",
               values_to = "se")

# Dot plot of the cluster means per attribute level, one facet row per
# attribute. The standard errors are joined in but not drawn.
imce_cluster_plot <- left_join(imce_cl_means, imce_cl_ses, by = c("var", "cluster")) %>%
  filter(var != "cluster") %>%
  rename(level = var) %>%
  left_join(resp_imces$att_lookup, by = "level") %>%
  mutate(
    Attribute = str_to_sentence(Attribute),
    Attribute = ifelse(Attribute == "Gdp", "GDP", Attribute),
    level = gsub("jobs_", "", gsub("gdp_", "", level)),
    level = factor(level, levels = macro_levels)
  ) %>%
  ggplot(aes(y = level, x = mean, color = cluster)) +
  facet_grid(Attribute ~ ., scales = "free_y") +
  geom_vline(xintercept = 0, linetype = "dashed") +
  geom_point(position = position_dodge(width = 0.5), size = 2.5, alpha = 0.7) +
  labs(x = "CATE", y = "", color = "Cluster") +
  theme_minimal() +
  theme(strip.text = element_text(face = "bold"),
        legend.position = "bottom")

plot(imce_cluster_plot)

ggsave(filename = "figures/CATE_by_cluster.png",
       plot = imce_cluster_plot,
       device = "png",
       width = 8, height = 8)



##### FIGURES S29 -- ATTENTION BY CLUSTER #####

# Mean of `attention` per cluster, with 95% intervals based on the standard
# error of a proportion, sqrt(p * (1 - p) / n).
cluster_attention <- imce_clusters %>%
  group_by(cluster) %>%
  summarise(
    prop = mean(attention),
    c_n = n(),
    se = sqrt((prop * (1 - prop)) / c_n)
  ) %>%
  ggplot(aes(x = cluster, y = prop, fill = as.factor(cluster))) +
  geom_col() +
  geom_errorbar(aes(ymin = prop - 1.96 * se, ymax = prop + 1.96 * se),
                width = 0.1) +
  scale_y_continuous(breaks = c(0, 0.25, 0.5)) +
  labs(x = "Cluster", y = "Proportion with high attention") +
  theme_minimal() +
  theme(legend.position = "none")

plot(cluster_attention)

ggsave(filename = "figures/attention_by_cluster.pdf",
       plot = cluster_attention,
       width = 4, height = 4)

#### FIGURE 4 -- IMCE DISTRIBUTION BY CLUSTER ####

# Long table of IMCEs by cluster: one row per respondent and attribute level,
# with display labels and a "Cluster" caption that includes the cluster size.
# NOTE(review): the level columns are taken from `resp_vimp$att_lookup`
# whereas the rest of the script uses `resp_imces` for the same purpose, so
# this block depends on the VIMP object being loaded -- confirm both lookups
# agree.
imces_cluster_tab <- imce_clusters %>% 
  pivot_longer(cols = all_of(
    resp_vimp$att_lookup$level[resp_vimp$att_lookup$level %in% colnames(imce_clusters)]
  ),
  names_to = "level"
  ) %>% 
  select(c_id, level, value, cluster) %>% 
  left_join(resp_imces$att_lookup, by = "level") %>% 
  mutate(Attribute = str_to_sentence(Attribute),
         Attribute = ifelse(Attribute == "Gdp","GDP",Attribute),
         # strip the prefixes added when building `bin_data`
         Level = gsub("gdp_","", Level),
         Level = gsub("jobs_","", Level),
         Level = factor(Level, levels = macro_levels)) %>% 
  group_by(cluster) %>% 
  # legend caption, e.g. "#1 (Subjects =<n>)"
  mutate(Cluster = paste0("#",cluster, " (Subjects =",length(unique(c_id)),")")) %>% 
  ungroup()

# Figure 4: left column = IMCE densities for the Deaths levels, right column
# = densities of four subject covariates, both split by cluster.
imces_cluster_tab %>% 
  filter(Attribute == "Deaths") %>% 
  select(Attribute, Level, Cluster, value) %>% 
  mutate(plot_col = "IMCE",
         Level = paste0(Attribute,": ",Level)) %>% 
  
  # append the subject covariates in the same long layout
  rbind(
    imce_clusters %>% 
      select(c_id, cluster, all_of(descriptive_num)) %>% 
      pivot_longer(!any_of(c("cluster", "c_id"))) %>% 
      group_by(cluster) %>% 
      mutate(Cluster = paste0("#",cluster, " (Subjects =",length(unique(c_id)),")")) %>% 
      ungroup() %>% 
      select(-c_id, -cluster) %>% 
      # only the four continuous covariates are plotted
      filter(name %in% c("age","ideology","food_pca","who_pca")) %>% 
      mutate(plot_col = "Subject covariate",
             Attribute = "Subject covariate") %>% 
      rename(Level = name) %>% 
      mutate(Level = var_format(Level))
  ) %>% 
  
  ggplot(., aes(x = value, fill = as.factor(Cluster))) +
  facet_wrap(plot_col + Level ~ ., scales = "free", ncol = 2, dir = "v") +
  geom_vline(xintercept = 0, color = "indianred", linewidth = 0.4, linetype = "dotted") +
  geom_density(alpha = 0.5) +
  # theme_minimal() +
  facetted_pos_scales(
    # IMCE panels share the Deaths x-range computed in `dist_scales`
    x = list(
      plot_col == "IMCE" ~ scale_x_continuous(
        limits = unlist(dist_scales[dist_scales$Panel == "Deaths", c("xmin","xmax")])
      )
    ),
    # square-root y axis for the food-poverty panel only
    y = list(
      plot_col == "Subject covariate" & Level == "Food poverty (PCA)" ~ scale_y_continuous(transform = "sqrt")
    )
  ) +
  guides(fill = guide_legend(ncol = 2)) +
  labs(x = "", y = "Density", fill = "Cluster") +
  theme(legend.position = "bottom",
        panel.background = element_blank())

# Saves the plot drawn above (ggsave() defaults to the last plot).
ggsave("figures/figure4.pdf", width = 5, height = 7)


##### FIGURE S25 -- FULL IMCE DISTRIBUTION PLOT BY CLUSTER #####

# IMCE densities for every attribute level, split by cluster. All levels of
# an attribute share the x-range stored in `dist_scales`.
ggplot(imces_cluster_tab, aes(x = value, fill = as.factor(Cluster))) +
  facet_wrap(Attribute + Level ~ ., scales = "free", nrow = 4, dir = "v") +
  geom_vline(xintercept = 0, linetype = "dotted", linewidth = 0.4, color = "indianred") +
  geom_density(alpha = 0.7) +
  # theme_minimal() +
  facetted_pos_scales(
    x = list(
      Attribute == "Deaths" ~ scale_x_continuous(
        limits = unlist(dist_scales[dist_scales$Panel == "Deaths", c("xmin", "xmax")])
      ),
      Attribute == "GDP" ~ scale_x_continuous(
        limits = unlist(dist_scales[dist_scales$Panel == "GDP", c("xmin", "xmax")])
      ),
      Attribute == "Jobs" ~ scale_x_continuous(
        limits = unlist(dist_scales[dist_scales$Panel == "Jobs", c("xmin", "xmax")])
      ),
      Attribute == "Lockdown" ~ scale_x_continuous(
        limits = unlist(dist_scales[dist_scales$Panel == "Lockdown", c("xmin", "xmax")])
      ),
      Attribute == "Supplies" ~ scale_x_continuous(
        limits = unlist(dist_scales[dist_scales$Panel == "Supplies", c("xmin", "xmax")])
      ),
      Attribute == "Vaccinated" ~ scale_x_continuous(
        limits = unlist(dist_scales[dist_scales$Panel == "Vaccinated", c("xmin", "xmax")])
      )
    )
  ) +
  labs(x = "IMCE", y = "Density", fill = "Cluster") +
  guides(fill = guide_legend(nrow = 1)) +
  theme(panel.background = element_blank(),
        legend.position = "bottom")

# Saves the plot drawn above (ggsave() defaults to the last plot).
ggsave("figures/imce_cluster_distributions.pdf", width = 10, height = 8)

##### FIGURE S23 -- VIMP MODEL WITH NATIONAL-LEVEL FEATURES #####

# bin_data2 <- read_csv("formatted_data/conjoint_data.csv") %>% 
#   set_refs() %>% 
#   select(-choice_cont, -weights,
#          -ends_with("_sub")) %>%
#   filter(!is.na(jobs)) %>% 
#   select(-id) %>% 
#   mutate(gdp = factor(paste0("gdp_",gdp), 
#                       levels = paste0("gdp_",levels(conjoint_data$gdp))),
#          jobs = factor(paste0("jobs_", jobs), 
#                        levels = paste0("jobs_",levels(conjoint_data$jobs)))) %>% 
#   select(-food_indx, -who_indx,
#          -starts_with("politics_"),
#          -starts_with("health_pol_"),
#          -matches("(C|E)[1-9]"),
#          -region_merge,
#          -attention) %>% 
#   
#   # select(-all_of(c("health_spend_pca","health_compl_pca","intl_flights"))) %>% 
#   filter(country != "CHN")
# 
# set.seed(89)
# t_mod_0 <- Sys.time()
# cj_mod2 <- cjbart(bin_data2,
#                  Y = "choice_bin",
#                  id = "c_id",
#                  round = "round",
#                  ntree = 750,
#                  cores = 8)
# t_mod_1 <- Sys.time()
# 
# write_rds(cj_mod2, paste0("models/conjoint_mod_",Sys.Date(),"_INC_REGIME.RDS"))
# 
# resp_imces2 <- IMCE(bin_data2,
#                    cj_mod2,
#                    attribs = c("gdp","jobs","supplies","lockdown","deaths","vaccinated"),
#                    ref_levels = c(gdp = "gdp_0%",
#                                   jobs = "jobs_0%",
#                                   supplies = "Quick to procure",
#                                   lockdown = "10 weeks",
#                                   deaths = "10 per million",
#                                   vaccinated = "5%"),
#                    cores = 8)
# write_rds(resp_imces2, paste0("models/conjoint_",Sys.Date(),"_INC_REGIME.RDS"))
# 
# set.seed(89)
# resp_vimp2 <- het_vimp(resp_imces2, cores = 8)
# write_rds(resp_vimp2, paste0("models/vimp_",Sys.Date(),"_INC_REGIME.RDS"))

# Variable-importance results of the "_INC_REGIME" model run, saved on
# 2024-08-13 (the code that produced it is commented out above).
resp_vimp2 <- readRDS("models/vimp_2024-08-13_INC_REGIME.RDS")

# Same labelling as for the main VIMP heatmap: attribute, cleaned level label,
# covariate group and display name.
vimp_results2 <- resp_vimp2$results %>% 
  mutate(attribute = case_when(Level == "Slow to procure" ~ "Procurement",
                               grepl("jobs_", Level) ~ "Job Growth",
                               grepl("gdp_", Level) ~ "GDP Growth",
                               grepl("per million", Level) ~ "Deaths",
                               grepl("weeks", Level) ~ "Lockdown Length",
                               # must come after the jobs_/gdp_ rules, which
                               # also contain "%"
                               grepl("%", Level) ~ "% Vaccinated"),
         out_lab = gsub("gdp_|jobs_","",Level),
         out_lab = factor(out_lab, levels = c("30 per million","50 per million","70 per million", "90 per million",
                                              "-10%","-5%","0%","+5%","+10%",
                                              "20 weeks","30 weeks","40 weeks",
                                              "Slow to procure",
                                              "15%","25%","50%","75%")),
         group = var_group(covar),
         group = factor(group,
                        levels = c("(Sub-) national context", "Demographics","Engagement","Economics","Health","Politics")),
         covar = var_format(covar),
         covar = factor(covar, levels = sort(unique(covar)))
         
  ) %>% 
  
  filter(covar != "Round")

# Heatmap: covariates on x, attribute levels on y, importance (x100) as fill.
ggplot(vimp_results2, aes(x = covar, y = out_lab, fill = importance*100)) +
  facet_grid(attribute ~ group, space = "free", scales = "free") +
  geom_tile() +
  colorspace::scale_fill_continuous_sequential("Reds3", breaks = seq(0,100,25)) +
  labs(x = "Subject covariate", y = "Attribute-level", fill = "Importance") +
  theme(text = element_text(size = 16),
        legend.position = "right",
        axis.text.x = element_text(angle=45, vjust = 1, hjust = 1, size = 11),
        axis.text.y = element_text(size = 11),
        strip.text.y.right = element_text(angle = 0, face = "bold",
                                          color = "grey40", hjust = 0, size = 11),
        strip.text.x = element_text(color = "grey40", size = 11, face = "bold"),
        strip.background =element_rect(fill=NA),
        # `linewidth` replaces the deprecated `size` argument of
        # element_rect() (ggplot2 >= 3.4.0)
        panel.border = element_rect(color = "grey40", fill = NA, linewidth = 2),
        
        legend.frame = element_rect(color = "grey40", linewidth = 1))

ggsave("figures/vimp_heatmap_INC_REGIME.pdf", dpi = 300, width = 13, height = 9)
